Our task is to predict flu infections based on the methods described here.

We used dataset Flu cases from 1997 through 2021 by week.

In [1]:
import datetime
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from scipy.optimize import least_squares
from scipy.signal import find_peaks

TODO: zmienić import na ten z ibm cloud

In [2]:
flu_data = pd.read_excel('flu_cases_1997_2021_raw_data.xlsx')
flu_data.head()
Out[2]:
YEAR WEEK Total_cases
0 1997 40 0
1 1997 41 11
2 1997 42 17
3 1997 43 8
4 1997 44 10
In [3]:
# we want to obtain date from columns YEAR and WEEK
flu_data['date'] = pd.to_datetime(flu_data.YEAR.astype(str), format='%Y') + \
             pd.to_timedelta(flu_data.WEEK.mul(7).astype(str) + ' days')
flu_data.head()
Out[3]:
YEAR WEEK Total_cases date
0 1997 40 0 1997-10-08
1 1997 41 11 1997-10-15
2 1997 42 17 1997-10-22
3 1997 43 8 1997-10-29
4 1997 44 10 1997-11-05
In [4]:
# we add additional column for day count
days = list(range(0, 7*len(flu_data), 7))
flu_data['days'] = days
flu_data.head()
Out[4]:
YEAR WEEK Total_cases date days
0 1997 40 0 1997-10-08 0
1 1997 41 11 1997-10-15 7
2 1997 42 17 1997-10-22 14
3 1997 43 8 1997-10-29 21
4 1997 44 10 1997-11-05 28

Flu cases from 1997 to 2021 - plot¶

In [5]:
fig = px.line(flu_data, x='date', y="Total_cases", hover_data=['date'])
fig.show()

We can see that data since April 2020 differ from the previous years -- it is probably because of the COVID-19 pandemics. We've decided to delete that data from our dataset.

In [6]:
flu_data = flu_data[~ flu_data['YEAR'].isin([2020, 2021])]
In [7]:
fig = px.line(flu_data, x='date', y="Total_cases", hover_data=['date'])
fig.show()

Finding local maxima¶

In [8]:
# peaks - indexes of peaks
peaks, _ = find_peaks(flu_data['Total_cases'], distance=24)
In [9]:
x = flu_data['Total_cases']
fig = px.line(flu_data, x='date', y="Total_cases", hover_data=['date'])
fig.add_scatter(x=flu_data['date'].iloc[peaks], y=x[peaks], mode="markers", marker_symbol='x', marker=dict(size=8), name='local maxima')
fig.show()
In [10]:
# choosing dataframe only for peaks
peaks_flu = flu_data.iloc[peaks]
In [11]:
def PolyCoefficients(x, coeffs):
    o = len(coeffs)
    y = 0
    for i in reversed(list(range(o))):
        y += coeffs[i]*x**i
    y += coeffs[-1]
    return y
In [12]:
x = flu_data['days']
In [13]:
p_1 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=1)
np.polyval(p_1, list(peaks_flu['days'])[-1] + 365)
Out[13]:
3767.5588269873924
In [14]:
p_2 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=2)
p_3 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=3)
p_4 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=4)
p_5 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=5)
p_6 = np.polyfit(x=peaks_flu['days'], y=peaks_flu['Total_cases'], deg=6)
In [15]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=3, cols=2)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_1), name='deg=1'),
    row=1, col=1
)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_2), name='deg=2'),
    row=1, col=2
)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_3), name='deg=3'),
    row=2, col=1
)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_4), name='deg=4'),
    row=2, col=2
)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_5), name='deg=5'),
    row=3, col=1
)

fig.add_trace(
    go.Scatter(x=x, y=PolyCoefficients(x, p_6), name='deg=6'),
    row=3, col=2
)


fig.update_layout(title_text="Polynomial plots for different degree of polynomial")
fig.show()

Final model¶

In [16]:
def KMcK(S1, I1, R1, n, b=0.434, a=0.4):
    S = [S1]
    I = [I1]
    R = [R1]
    for i in range(n):
        S.append(S[i] - b * I[i] * S[i])
        I.append(I[i] + b * I[i] * S[i] - a * I[i])
        R.append(R[i] + a * I[i])
    return S, I, R
In [17]:
KMcK(
    S1=1000,
    I1=11,
    R1=0,
    n=10,
)
Out[17]:
([1000,
  -3774.0,
  7826447.2296,
  26587005881167.25,
  3.0678105855204267e+26,
  4.0845744162657916e+52,
  7.240746702313656e+104,
  2.275393115826668e+209,
  inf,
  inf,
  nan],
 [11,
  4780.6,
  -7827352.869600001,
  -26587002751131.742,
  -3.06781058552032e+26,
  -4.0845744162657916e+52,
  -7.240746702313656e+104,
  -2.275393115826668e+209,
  -inf,
  nan,
  nan],
 [0,
  4.4,
  1916.6400000000003,
  -3129024.5078400006,
  -10634804229477.205,
  -1.2271242342082344e+26,
  -1.6338297665063168e+52,
  -2.8962986809254624e+104,
  -9.101572463306672e+208,
  -inf,
  nan])
In [ ]: